package top.mingyi4cjh.cms.common.interceptor;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import top.mingyi4cjh.cms.common.configuration.AccessLimit;
import top.mingyi4cjh.cms.common.error.EmBusinessError;
import top.mingyi4cjh.cms.common.error.HandlerError;
import top.mingyi4cjh.cms.common.response.CommonReturnType;
import top.mingyi4cjh.cms.service.RedisService;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import static top.mingyi4cjh.cms.common.utils.Util.replaceStr;

/**
 * Spring MVC interceptor that enforces the per-client request limit declared by
 * {@link AccessLimit} on a handler method.
 *
 * <p>For each request mapped to a {@link HandlerMethod} annotated with {@code @AccessLimit}, a
 * counter keyed by {@code <client ip>:<servlet path>} is kept through {@link RedisService}. The
 * first request creates the counter via {@code redisService.set(key, 1, seconds)} (the third
 * argument is presumably an expiry in seconds — TODO confirm against RedisService). Later
 * requests increment it while it is below {@link AccessLimit#maxCount()}. Once the limit is
 * reached, the request is rejected with an {@link EmBusinessError#ACCESS_FREQUENTLY} JSON body.
 *
 * <p>NOTE(review): the read-then-write on the counter is not atomic, so concurrent requests can
 * slightly exceed {@code maxCount}. An atomic INCR + EXPIRE would close this gap, but it depends
 * on RedisService APIs not visible here.
 *
 * @author MingYi
 */
@Component
public class AccessLimitInterceptor implements HandlerInterceptor {

    /** Placeholder value some proxies put in forwarding headers when the client address is unknown. */
    public static final String UNKNOWN_VALUE = "unknown";

    /**
     * Proxy headers consulted, in order, for the originating client address:
     * X-Forwarded-For (Squid and most proxies), Proxy-Client-IP (Apache),
     * WL-Proxy-Client-IP (WebLogic), HTTP_CLIENT_IP (some other proxies), X-Real-IP (nginx).
     */
    private static final String[] IP_HEADERS = {
        "X-Forwarded-For", "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "X-Real-IP"
    };

    /** Shared serializer for the rejection body; ObjectMapper is thread-safe once configured. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Resource
    private RedisService redisService;

    /**
     * Applies the {@link AccessLimit} policy of the target handler method, if any.
     *
     * @param request  the current request
     * @param response the current response; receives a JSON error body when the limit is hit
     * @param handler  the chosen handler; only {@link HandlerMethod}s are rate-limited
     * @return {@code true} to continue processing, {@code false} when the request was rejected
     * @throws Exception if serializing or writing the rejection body fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            return HandlerInterceptor.super.preHandle(request, response, handler);
        }
        AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
        if (accessLimit == null) {
            return true;
        }

        int seconds = accessLimit.seconds();
        int maxCount = accessLimit.maxCount();

        String key = getIpInformation(request) + ":" + request.getServletPath();
        Object stored = redisService.get(key);
        // Accept any numeric type: depending on the Redis serializer, an incremented counter may
        // be read back as Long rather than Integer. A missing key and a stored -1 are both treated
        // as "no active counter" (-1 is presumably a RedisService sentinel — TODO confirm).
        int count = stored == null ? -1 : ((Number) stored).intValue();
        if (count == -1) {
            redisService.set(key, 1, seconds);
            return true;
        }
        if (count < maxCount) {
            redisService.incr(key, 1);
            return true;
        }

        // Limit reached: short-circuit the request with a JSON error body.
        HandlerError handlerError = new HandlerError(EmBusinessError.ACCESS_FREQUENTLY);
        String json = OBJECT_MAPPER.writeValueAsString(CommonReturnType.create(handlerError, "fail"));
        response.setContentType("application/json;charset=UTF-8");
        response.getWriter().println(json);
        return false;
    }

    /**
     * Resolves the client IP address used as part of the rate-limit key.
     *
     * <p>The headers in {@link #IP_HEADERS} are consulted in order. The first value that is
     * non-blank and not {@value #UNKNOWN_VALUE} wins. When that value is a comma-separated chain
     * (multiple proxy hops), its first entry is taken as the client address. If that entry is
     * itself blank or {@value #UNKNOWN_VALUE}, or no usable header exists,
     * {@link HttpServletRequest#getRemoteAddr()} is used. The result is passed through
     * {@code Util.replaceStr} before being returned.
     *
     * <p>NOTE(review): these headers are client-controlled unless a trusted proxy overwrites
     * them, so a client can rotate their values to evade the limit.
     *
     * @param request the current request
     * @return the resolved client address after {@code replaceStr} post-processing
     */
    private String getIpInformation(HttpServletRequest request) {
        String ip = null;
        String ipAddresses = firstKnownHeader(request);
        if (ipAddresses != null) {
            // indexOf/substring instead of split(",")[0]: split drops trailing empty strings, so a
            // header value of "," would yield an empty array and throw.
            int comma = ipAddresses.indexOf(',');
            ip = (comma >= 0 ? ipAddresses.substring(0, comma) : ipAddresses).trim();
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        return replaceStr(ip);
    }

    /** Returns the first {@link #IP_HEADERS} value that is usable, or {@code null} if none is. */
    private static String firstKnownHeader(HttpServletRequest request) {
        for (String header : IP_HEADERS) {
            String value = request.getHeader(header);
            if (!isUnknown(value)) {
                return value;
            }
        }
        return null;
    }

    /** Returns {@code true} when {@code value} is null, blank, or the {@value #UNKNOWN_VALUE} placeholder. */
    private static boolean isUnknown(String value) {
        if (value == null) {
            return true;
        }
        String trimmed = value.trim();
        return trimmed.isEmpty() || UNKNOWN_VALUE.equalsIgnoreCase(trimmed);
    }
}
